# Plot per-epoch training curves for every method directory under _DATA_DIR.
#
# Each sub-directory is read as four header-less files (train_accs.txt,
# train_losses.txt, val_accs.txt, val_losses.txt); row i is plotted as
# epoch i. For each method two figures are shown: accuracy (fig1) and loss
# (fig2). Relies on `os`, `pd`, `go` (plotly.graph_objects) and a `colors`
# palette defined elsewhere in the file.

_DATA_DIR = "../data"


def _read_metric(method_dir, file_name, column_name):
    """Read one header-less metric file of *method_dir* into a DataFrame.

    The file's unnamed first column (label 0) is renamed to *column_name*.
    """
    path = os.path.join(_DATA_DIR, method_dir, file_name)
    return pd.read_csv(path, header=None).rename(columns={0: column_name})


def _metric_figure(err_df, columns, title, yaxis_title, color_offset=0):
    """Build a line chart with one trace per column of *err_df* vs. its index.

    Trace i gets palette color ``colors[(i + color_offset) % len(colors)]``.
    """
    fig = go.Figure()
    for index, col in enumerate(columns):
        fig.add_trace(
            go.Scatter(
                x=err_df.index,
                y=err_df[col],
                mode="lines",
                name=col,
                # Wrap *after* applying the offset so the lookup can never
                # run past the end of a short palette.
                line=dict(color=colors[(index + color_offset) % len(colors)]),
            )
        )
    fig.update_layout(
        template="plotly_white",
        legend_title="<b>Measure: </b>",
        title_text=title,
        xaxis_title="Epoch",
        yaxis_title=yaxis_title,
        width=800,
        height=700,
    )
    return fig


# sorted() makes the plot order deterministic; os.listdir order is arbitrary.
for method_dir in sorted(os.listdir(_DATA_DIR)):
    # Skip stray files (e.g. .DS_Store): only sub-directories hold metric files.
    if not os.path.isdir(os.path.join(_DATA_DIR, method_dir)):
        continue
    train_acc = _read_metric(method_dir, "train_accs.txt", "train_accuracy")
    train_loss = _read_metric(method_dir, "train_losses.txt", "train_loss")
    val_acc = _read_metric(method_dir, "val_accs.txt", "validation_accuracy")
    val_loss = _read_metric(method_dir, "val_losses.txt", "validation_loss")
    # Place the four metrics side by side, aligned on the row (epoch) index.
    err_df = pd.concat([train_acc, train_loss, val_acc, val_loss], axis=1)
    epochs = err_df.shape[0]  # not used in this block; kept for later code
    plot_title = method_dir
    fig1 = _metric_figure(
        err_df, ["train_accuracy", "validation_accuracy"], plot_title, "Accuracy"
    )
    # Offset by 2 so the loss curves use different palette colors than accuracy.
    fig2 = _metric_figure(
        err_df, ["train_loss", "validation_loss"], plot_title, "Loss", color_offset=2
    )
    fig1.show()
    fig2.show()